from typing import Tuple, List

import jax.numpy as jnp
import jax
from functools import partial

from common import Batch, InfoDict, Model, Params, PRNGKey


def chi_square_loss(diff, alpha, args=None):
    """Chi-square style value loss applied elementwise to the residual ``diff``.

    Computes ``alpha * max(diff + diff**2 / 4, 0) - (1 - alpha) * diff``.

    Args:
        diff: residual array (e.g. ``Q - V``).
        alpha: mixing temperature between the hinge term and the linear term.
        args: accepted for a uniform loss signature; not used.

    Returns:
        Array with the same shape as ``diff``.
    """
    positive_part = jnp.maximum(diff + diff**2 / 4, 0)
    return alpha * positive_part - (1 - alpha) * diff

def total_variation_loss(diff, alpha, args=None):
    """Total-variation style value loss applied elementwise to ``diff``.

    Computes ``alpha * max(diff, 0) - (1 - alpha) * diff``.

    Args:
        diff: residual array (e.g. ``Q - V``).
        alpha: mixing temperature between the hinge term and the linear term.
        args: accepted for a uniform loss signature; not used.

    Returns:
        Array with the same shape as ``diff``.
    """
    hinge = jnp.maximum(diff, 0)
    return alpha * hinge - (1 - alpha) * diff

def recoil_loss(diff, alpha, args=None):
    """Exponential value loss with the exponential capped at 100, plus a hinge term.

    Computes ``min(exp(alpha * diff), 100) + alpha * max(diff, 0)`` elementwise.

    Args:
        diff: residual array (e.g. ``Q - V``).
        alpha: scale applied inside the exponential and to the hinge.
        args: accepted for a uniform loss signature; not used.

    Returns:
        Array with the same shape as ``diff``.
    """
    capped_exp = jnp.minimum(jnp.exp(alpha * diff), 100)
    hinge = alpha * jnp.maximum(diff, 0)
    return capped_exp + hinge

def reverse_kl_loss(diff, alpha, args=None):
    """Reverse-KL (Gumbel) loss ``e^z - z - 1`` with ``z = diff / alpha``, rescaled for stability.

    The returned value equals ``exp(-max_z) * (e^z - z - 1)``, where ``max_z`` is
    the (gradient-stopped) maximum of ``z`` over axis 0, floored at -1. Dividing
    by ``e^max_z`` keeps the exponential bounded for outliers; this acts like an
    adaptive learning rate.

    Args:
        diff: residual array (e.g. ``Q - V``); the max is taken over axis 0.
        alpha: temperature dividing ``diff``.
        args: optional config. If it has a non-None ``max_clip`` attribute, ``z``
            is clipped from above at that value before the loss is computed.
            May be ``None`` (the default), in which case no clipping happens.

    Returns:
        Array with the same shape as ``diff``.
    """
    z = diff / alpha
    # `args` defaults to None, so do not dereference it unconditionally.
    max_clip = getattr(args, 'max_clip', None)
    if max_clip is not None:
        z = jnp.minimum(z, max_clip)  # clip max value
    # Floor the scale at -1 so exp(-max_z) can never exceed e.
    max_z = jnp.maximum(jnp.max(z, axis=0), -1.0)
    max_z = jax.lax.stop_gradient(max_z)  # the scale must not receive gradients
    return jnp.exp(z - max_z) - z * jnp.exp(-max_z) - jnp.exp(-max_z)

def expectile_loss(diff, expectile=0.8):
    """Asymmetric squared (expectile) loss.

    Positive residuals are weighted by ``expectile`` and the rest by
    ``1 - expectile``.

    Args:
        diff: residual array.
        expectile: weight for positive residuals.

    Returns:
        Array with the same shape as ``diff``.
    """
    squared = diff**2
    return jnp.where(diff > 0, expectile, (1 - expectile)) * squared
 

def update_v(critic: Model, value: Model, batch: Batch,
             expectile: float, loss_temp: float, double: bool, vanilla: bool, key: PRNGKey, args) -> Tuple[Model, InfoDict]:
    """Update the value network by regressing it towards critic Q-values.

    The per-sample residual ``Q(s, a) - V(s)`` is passed through each loss
    named in ``args.f`` (a ``'+'``-separated list). The batch means of those
    losses are summed.

    Args:
        critic: called as ``critic(obs, acts)``; must return a pair ``(q1, q2)``.
        value: model being trained. Its ``apply`` computes V and its
            ``apply_gradient(value_loss_fn)`` produces the returned
            ``(new_value, info)``.
        batch: provides ``observations`` and ``actions``.
        expectile: unused here (kept for interface compatibility).
        loss_temp: temperature ``alpha`` forwarded to the selected loss.
        double: if True the target is ``min(q1, q2)``, otherwise ``q1``.
        vanilla: unused here (kept for interface compatibility).
        key: PRNG key for random-action sampling and action noise.
        args: config. Reads ``sample_random_times``, ``noise``, ``noise_std``
            and ``f``, plus whatever the selected loss reads (``max_clip``
            for the reverse-KL loss).

    Returns:
        ``(new_value, info)``, where ``info`` contains ``'value_loss'`` and ``'v'``
        (the mean value prediction).

    Raises:
        ValueError: if ``args.f`` names an unknown loss.
    """
    # 'recoil' intentionally uses the reverse-KL (XQL-style Gumbel) value loss.
    loss_fns = {
        'chi-square': chi_square_loss,
        'total-variation': total_variation_loss,
        'recoil': reverse_kl_loss,
        'reverse-kl': reverse_kl_loss,
    }
    selected_losses = []
    for name in args.f.split('+'):
        if name not in loss_fns:
            raise ValueError(f"Unknown value loss {name!r} in args.f={args.f!r}")
        selected_losses.append(loss_fns[name])

    actions = batch.actions

    rng1, rng2 = jax.random.split(key)
    if args.sample_random_times > 0:
        # Augment the batch with `times` uniform actions per observation, so the loss
        # is computed over a mixture of dataset and uniform actions
        # (an equal 1/2(rho + Unif) mix when sample_random_times == 1).
        times = args.sample_random_times
        random_action = jax.random.uniform(
            rng1, shape=(times * actions.shape[0],
                         actions.shape[1]),
            minval=-1.0, maxval=1.0)
        obs = jnp.concatenate([batch.observations, jnp.repeat(
            batch.observations, times, axis=0)], axis=0)
        acts = jnp.concatenate([batch.actions, random_action], axis=0)
    else:
        obs = batch.observations
        acts = batch.actions

    if args.noise:
        # Perturb the full (possibly augmented) action set. The noise has acts'
        # shape, so it must be added to `acts`, not to `batch.actions`.
        noise = jax.random.normal(rng2, shape=acts.shape)
        noise = jnp.clip(noise * args.noise_std, -0.5, 0.5)
        acts = jnp.clip(acts + noise, -1, 1)

    q1, q2 = critic(obs, acts)
    q = jnp.minimum(q1, q2) if double else q1

    def value_loss_fn(value_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        v = value.apply({'params': value_params}, obs)
        diff = q - v
        value_loss = sum(
            loss_fn(diff, alpha=loss_temp, args=args).mean()
            for loss_fn in selected_losses)
        return value_loss, {
            'value_loss': value_loss,
            'v': v.mean(),
        }

    new_value, info = value.apply_gradient(value_loss_fn)

    return new_value, info


def update_q(critic: Model, target_value: Model, expert_batch: Batch, suboptimal_batch: Batch, mix_batch: Batch, 
             discount: float, double: bool, key: PRNGKey, loss_temp: float, args) -> Tuple[Model, InfoDict]:
    """Update the critic with a Huber TD loss plus optional ReCOIL and gradient-penalty terms.

    Base objective: a Huber regression (delta = 20) of the critic head(s) on
    ``mix_batch`` towards ``r + discount * mask * V_target(s')``.

    If ``'recoil'`` occurs in ``args.f``, a further term is added. It pulls the
    expert Q-values towards a fixed target (200) and pushes down the larger of
    the two suboptimal Q-values. If ``args.grad_pen`` is set, a one-sided
    action-gradient penalty weighted by ``args.lambda_gp`` is added as well.

    Args:
        critic: model with ``apply`` returning ``(q1, q2)``. Its
            ``apply_gradient(critic_loss_fn)`` produces the returned pair.
        target_value: called as ``target_value(obs)`` to get V estimates.
        expert_batch: expert transitions (used only for the ReCOIL term).
        suboptimal_batch: suboptimal transitions (used only for the ReCOIL term).
        mix_batch: transitions for the TD loss. Provides ``observations``,
            ``actions``, ``rewards``, ``masks`` and ``next_observations``.
        discount: TD discount factor.
        double: if True both heads are trained on the TD loss, otherwise only ``q1``.
        key: unused here (kept for interface compatibility).
        loss_temp: forwarded to the inner loss helper, which ignores it.
        args: config. Reads ``f``, ``grad_pen`` and ``lambda_gp``.

    Returns:
        ``(new_critic, info)`` with the loss/diagnostic dictionary.
    """
    # Fixed constants of the critic objective (same values as before).
    huber_delta = 20.0
    expert_q_target = 200
    recoil_weight = 4 * 0.9

    def critic_loss_fn(critic_params: Params) -> Tuple[jnp.ndarray, InfoDict]:
        # TD target. target_value does not depend on critic_params, so no gradient flows into it.
        mix_next_v = target_value(mix_batch.next_observations)
        mix_target_q = mix_batch.rewards + discount * mix_batch.masks * mix_next_v

        mix_acts = mix_batch.actions
        mix_q1, mix_q2 = critic.apply({'params': critic_params}, mix_batch.observations, mix_acts)
        mix_v = target_value(mix_batch.observations)

        def mse_loss(q, q_target, *_unused):
            # Despite the name this is a Huber loss (quadratic up to huber_delta,
            # linear beyond). Extra positional arguments are accepted and ignored.
            loss = huber_loss(q - q_target, delta=huber_delta).mean()
            return loss, {'critic_loss': loss}

        if double:
            loss1, loss_dict = mse_loss(mix_q1, mix_target_q, mix_v, loss_temp)
            loss2, dict2 = mse_loss(mix_q2, mix_target_q, mix_v, loss_temp)
            critic_loss = (loss1 + loss2).mean()
            for k, val in dict2.items():
                loss_dict[k] += val
        else:
            # Must use mix_v: `v` is not defined on this path (UnboundLocalError).
            critic_loss, loss_dict = mse_loss(mix_q1, mix_target_q, mix_v, loss_temp)

        if 'recoil' in args.f:
            # NOTE(review): a previous comment cited the paper's config (30 expert demos vs.
            # 1M suboptimal transitions => beta = 30/1000000), but the weight actually
            # applied is recoil_weight = 4 * 0.9. Confirm which is intended.
            expert_q1, expert_q2 = critic.apply(
                {'params': critic_params}, expert_batch.observations, expert_batch.actions)
            expert_loss1, expert_dict1 = mse_loss(expert_q1, expert_q_target)
            expert_loss2, expert_dict2 = mse_loss(expert_q2, expert_q_target)
            expert_loss = (expert_loss1 + expert_loss2).mean()
            loss_dict.update(
                {'recoil_' + k: expert_dict1[k] + expert_dict2[k] for k in expert_dict1})

            suboptimal_q1, suboptimal_q2 = critic.apply(
                {'params': critic_params}, suboptimal_batch.observations, suboptimal_batch.actions)
            suboptimal_loss = jnp.maximum(suboptimal_q1, suboptimal_q2).mean()

            # Named recoil_term so it does not shadow the module-level recoil_loss function.
            recoil_term = recoil_weight * (suboptimal_loss + expert_loss)
            critic_loss += recoil_term

        if args.grad_pen:
            lambda_ = args.lambda_gp
            q1_grad, q2_grad = grad_norm(critic, critic_params, mix_batch.observations, mix_acts)
            loss_dict['q1_grad'] = q1_grad.mean()
            loss_dict['q2_grad'] = q2_grad.mean()

            if double:
                gp_loss = (q1_grad + q2_grad).mean()
            else:
                gp_loss = q1_grad.mean()

            critic_loss += lambda_ * gp_loss

        loss_dict.update({
            'q1': mix_q1.mean(),
            'q2': mix_q2.mean()
        })

        return critic_loss, loss_dict

    new_critic, info = critic.apply_gradient(critic_loss_fn)

    return new_critic, info


def grad_norm(model, params, obs, action, lambda_=10):
    """Per-sample one-sided gradient penalty of both critic heads w.r.t. the action.

    For every ``(obs, action)`` pair (mapped over axis 0), ``model.apply`` is
    differentiated with respect to the action. Each head's Jacobian is then
    penalised as ``max(||dQ/da|| - 1, 0) ** 2`` (WGAN-LP style).

    Args:
        model: object whose ``apply({'params': params}, obs, action)`` returns
            a pair ``(q1, q2)``.
        params: parameters passed to ``model.apply``.
        obs: batched observations, mapped over axis 0.
        action: batched actions, mapped over axis 0.
        lambda_: unused; kept for backward compatibility. The penalty is
            returned unweighted.

    Returns:
        ``(penalty1, penalty2)``: per-sample penalties for the first and the
        second head, respectively.
    """

    @partial(jax.vmap, in_axes=(0, 0))
    @partial(jax.jacrev, argnums=1)
    def input_grad_fn(single_obs, single_action):
        return model.apply({'params': params}, single_obs, single_action)

    def grad_pen_fn(grad):
        # Penalise only the part of the gradient norm above 1. Use the argument
        # itself, not an enclosing variable, so each head gets its own penalty.
        return jnp.maximum(jnp.linalg.norm(grad, axis=-1) - 1, 0)**2

    grad1, grad2 = input_grad_fn(obs, action)

    return grad_pen_fn(grad1), grad_pen_fn(grad2)


def huber_loss(x, delta: float = 1.):
    """Huber loss: quadratic near zero, linear beyond ``delta``.

    See "Robust Estimation of a Location Parameter" by Huber
    (https://projecteuclid.org/download/pdf_1/euclid.aoms/1177703732).

    Piecewise definition:
        0.5 * x^2                  if |x| <= delta
        0.5 * delta^2 + delta * (|x| - delta)  otherwise

    Note `grad(huber_loss(x))` is equivalent to `grad(0.5 * clip_gradient(x)**2)`.

    Args:
        x: array of arbitrary shape.
        delta: switch point between the quadratic and linear regimes; defaults to 1.

    Returns:
        Array with the same shape as ``x``.
    """
    magnitude = jnp.abs(x)
    inner = jnp.minimum(magnitude, delta)
    # Written as |x| - min(|x|, delta) rather than max(|x| - delta, 0) so the
    # gradient is not potentially doubled at the switch point.
    outer = magnitude - inner
    return delta * outer + 0.5 * inner**2
